package com.uxsino.simo.collector.connections;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tracks TCP connection attempts per remote address so callers can throttle
 * and temporarily suspend connections to an endpoint.
 * <p>
 * Two mechanisms are provided:
 * <ul>
 * <li>Rate limiting: {@link #waitFor(InetSocketAddress)} sleeps until at least
 * {@link #getWaitMilisecondsFor(InetSocketAddress)} milliseconds have passed
 * since the last connection recorded for the same address.</li>
 * <li>Error suspension: once more than the tolerated number of consecutive
 * failures has been reported through
 * {@link #updateAccessTime(InetSocketAddress, boolean)},
 * {@link #isInSuspend(InetSocketAddress)} reports the address as suspended for
 * a configurable period.</li>
 * </ul>
 * Per-address state lives in a {@link ConcurrentHashMap} of atomic fields so
 * one instance can be shared between threads. Entries are never removed.
 */
public class TcpSessionManager {
    private static final Logger logger = LoggerFactory.getLogger(TcpSessionManager.class);

    /** Thread-safe connection bookkeeping for a single remote address. */
    private static class AddrStatus {
        /** Wall-clock millis of the most recently recorded connection. */
        public final AtomicLong millisLastConnectAt = new AtomicLong(0);

        /** Consecutive failed connections; reset on success or when a suspension expires. */
        public final AtomicInteger contiguousConnectErrorCount = new AtomicInteger(0);

        /** Wall-clock millis at which the current error suspension started. */
        public final AtomicLong millisContiguousConnectErrorSuspendStartedAt = new AtomicLong(0);
    }

    // per-address connection bookkeeping, keyed by socket address
    private final ConcurrentHashMap<InetSocketAddress, AddrStatus> accessTimes = new ConcurrentHashMap<>();

    // consecutive failures tolerated before an address is suspended;
    // volatile because setters may run while other threads read it
    private volatile int contiguousConnectErrorTolerantTimes = 3;

    // how long (ms) an address stays suspended once the tolerance is exceeded
    private volatile long millisContiguousConnectErrorSuspend = 60000;

    // minimum number of milliseconds between two connections to the same address
    private volatile long millisWaitBetween = 100;

    /**
     * Returns the minimum gap, in milliseconds, to keep between two connections
     * to the given address. Currently every address uses the same default.
     *
     * @param socketAddr the remote address (not yet used to vary the result)
     * @return the configured wait in milliseconds
     */
    public long getWaitMilisecondsFor(InetSocketAddress socketAddr) {
        // use same default value for now
        return millisWaitBetween;
    }

    /**
     * Returns the status record for {@code addr}, creating and registering one
     * if none exists. Uses {@code putIfAbsent} so that concurrent callers
     * always end up sharing the same record.
     *
     * @param addr the remote address
     * @return the (possibly newly created) status record, never {@code null}
     */
    private AddrStatus getAddrStatus(InetSocketAddress addr) {
        AddrStatus status = accessTimes.get(addr);
        if (status != null) {
            return status;
        }
        AddrStatus created = new AddrStatus();
        AddrStatus existing = accessTimes.putIfAbsent(addr, created);
        // another thread may have registered a record first; use that one
        return existing != null ? existing : created;
    }

    /**
     * Resolves {@code host} and waits a minimum time since the last connection
     * to it. If the host cannot be resolved the error is logged and the method
     * returns without waiting.
     *
     * @param host host name or literal IP address
     * @param port remote port
     */
    public void waitFor(String host, int port) {
        try {
            InetAddress addr = InetAddress.getByName(host);
            waitFor(new InetSocketAddress(addr, port));
        } catch (UnknownHostException e) {
            logger.error("unknown host:", e);
        }
    }

    /**
     * Sleeps until at least {@link #getWaitMilisecondsFor(InetSocketAddress)}
     * milliseconds have passed since the last connection recorded for
     * {@code socketAddr}, then records the current time as the last connection.
     * <p>
     * If the thread is interrupted while sleeping, the wait ends early, the
     * interrupt status is restored, and the time is still recorded.
     * <p>
     * NOTE: the read of the previous timestamp and the sleep are not atomic,
     * so concurrent callers that observe the same timestamp may proceed
     * together; the minimum gap is only guaranteed between sequential callers.
     *
     * @param socketAddr the remote address
     */
    public void waitFor(InetSocketAddress socketAddr) {
        long now = System.currentTimeMillis();
        AddrStatus status = getAddrStatus(socketAddr);
        long lastConnectAt = status.millisLastConnectAt.get();
        long remaining = lastConnectAt + getWaitMilisecondsFor(socketAddr) - now;
        if (remaining > 0) {
            try {
                TimeUnit.MILLISECONDS.sleep(remaining);
            } catch (InterruptedException e) {
                // stop waiting early, but keep the interrupt visible to callers
                Thread.currentThread().interrupt();
            }
        }
        setMaxLong(status.millisLastConnectAt, System.currentTimeMillis());
    }

    /**
     * Resolves {@code host} and records a connection attempt to it. If the host
     * cannot be resolved the error is logged and nothing is recorded.
     *
     * @param host host name or literal IP address
     * @param port remote port
     * @param success whether the connection succeeded
     */
    public void updateAccessTime(String host, int port, boolean success) {
        try {
            InetAddress addr = InetAddress.getByName(host);
            updateAccessTime(new InetSocketAddress(addr, port), success);
        } catch (UnknownHostException e) {
            logger.error("unknown host:", e);
        }
    }

    /**
     * Records a connection attempt to {@code socketAddr}.
     * <p>
     * The last-connect time is advanced to now. On success the consecutive
     * failure count is cleared. On failure the count is incremented and, as
     * long as the previous count had not exceeded the tolerance, the suspension
     * start time is reset to now. The suspension therefore effectively begins at
     * the failure that pushes the count past the tolerance. Later failures
     * during the suspension do not extend it.
     *
     * @param socketAddr the remote address
     * @param success if this connection succeeded
     */
    public void updateAccessTime(InetSocketAddress socketAddr, boolean success) {
        AddrStatus status = getAddrStatus(socketAddr);
        setMaxLong(status.millisLastConnectAt, System.currentTimeMillis());
        if (success) {
            status.contiguousConnectErrorCount.set(0);
            return;
        }

        int previousCount = status.contiguousConnectErrorCount.getAndIncrement();
        if (previousCount > contiguousConnectErrorTolerantTimes) {
            // already in suspend
            return;
        }

        // (re)start the suspend timer
        status.millisContiguousConnectErrorSuspendStartedAt.set(System.currentTimeMillis());
    }

    /**
     * Returns whether connections to {@code host:port} are currently suspended.
     * The host is resolved by the {@link InetSocketAddress} constructor; an
     * unresolvable host has no recorded failures and so is never suspended.
     *
     * @param host host name or literal IP address
     * @param port remote port
     * @return {@code true} if the address is suspended
     */
    public boolean isInSuspend(String host, int port) {
        return isInSuspend(new InetSocketAddress(host, port));
    }

    /**
     * Returns whether connections to {@code addr} are currently suspended.
     * <p>
     * An address is suspended when more than the tolerated number of
     * consecutive failures have been recorded and the suspension period has
     * not yet elapsed. When the period has elapsed, the failure count is reset,
     * so the address gets a fresh set of tolerated failures.
     *
     * @param addr the remote address
     * @return {@code true} if the address is suspended
     */
    public boolean isInSuspend(InetSocketAddress addr) {
        // read-only lookup: querying an unknown address must not allocate an entry
        AddrStatus status = accessTimes.get(addr);
        if (status == null) {
            return false;
        }
        if (status.contiguousConnectErrorCount.get() <= contiguousConnectErrorTolerantTimes) {
            return false;
        }

        long suspendStartedAt = status.millisContiguousConnectErrorSuspendStartedAt.get();
        boolean inSuspend = System.currentTimeMillis() - suspendStartedAt < millisContiguousConnectErrorSuspend;

        if (!inSuspend) {
            status.contiguousConnectErrorCount.set(0);
        }
        return inSuspend;
    }

    /**
     * Sets the minimum time span between two connections to the same address.
     *
     * @param defaultWaitMiliseconds the minimum gap in milliseconds
     */
    public void setDefaultWaitMiliseconds(long defaultWaitMiliseconds) {
        this.millisWaitBetween = defaultWaitMiliseconds;
    }

    /**
     * Sets how long an address stays suspended after exceeding the tolerated
     * number of consecutive connection failures.
     *
     * @param millsSuspend suspension length in milliseconds
     */
    public void setContiguousConnectErrorSuspendMilliseconds(long millsSuspend) {
        millisContiguousConnectErrorSuspend = millsSuspend;
    }

    /**
     * Sets how many consecutive connection failures are tolerated before an
     * address is suspended.
     *
     * @param times number of tolerated consecutive failures
     */
    public void setContiguousConnectErrorTolerantTime(int times) {
        contiguousConnectErrorTolerantTimes = times;
    }

    /**
     * Atomically sets {@code target} to {@code max(target, candidate)}.
     * Retries the CAS until either it succeeds or another thread has stored a
     * value at least as large. A single failed CAS is not enough, because the
     * competing value may be smaller than {@code candidate}.
     *
     * @param target the value to raise
     * @param candidate the proposed new value
     */
    private static void setMaxLong(AtomicLong target, long candidate) {
        long current = target.get();
        while (current < candidate) {
            if (target.compareAndSet(current, candidate)) {
                return;
            }
            current = target.get();
        }
    }

}
